import numpy as np

from diffgro.environments.metaworld.policies.policy import HeuristicPolicy
from diffgro.environments.metaworld.policies.skill import *


############# Standard Policy #############


class BoxV2Policy(HeuristicPolicy):
    """Heuristic policy that grasps the lid and carries it to ``box_pos``.

    The skill list has two phases:
    approach and grasp the lid with ``grab_effort=-1`` followed by a
    ``GrabSkill``, then move with ``grab_effort=1`` over the box (x, y)
    at a fixed base height of 0.15 and descend.
    """

    def _initialize_skill(self, o_d):
        # Phase 1: approach the lid (offset 0.02 above its reported position).
        lid_x, lid_y, lid_z = o_d["lid_pos"] + np.array([0.0, 0.0, 0.02])
        pick_phase = [
            MoveZSkill(lid_z + 0.06, grab_effort=-1),
            MoveYSkill(lid_y, grab_effort=-1),
            MoveXSkill(lid_x, grab_effort=-1),
            MoveZSkill(lid_z + 0.04, grab_effort=-1),
            GrabSkill(grab_effort=1),
        ]

        # Phase 2: box_pos only carries (x, y); z is pinned to 0.15.
        box_x, box_y, base_z = np.array([*o_d["box_pos"], 0.15])
        place_phase = [
            MoveZSkill(base_z + 0.04, grab_effort=1),
            MoveYSkill(box_y, grab_effort=1),
            MoveXSkill(box_x, grab_effort=1),
            MoveZSkill(base_z + 0.02, grab_effort=1),
        ]

        self.skills = pick_phase + place_phase

    @staticmethod
    def _parse_obs(obs):
        # Layout: hand xyz | gripper | lid xyz | ... | box xy | trailing scalar
        return dict(
            hand_pos=obs[:3],
            gripper=obs[3],
            lid_pos=obs[4:7],
            extra_info_1=obs[7:-3],
            box_pos=obs[-3:-1],
            extra_info_2=obs[-1],
        )


class PuckV2Policy(HeuristicPolicy):
    """Heuristic policy that lines up behind the puck and sweeps along x.

    The hand is placed at the puck's (x - 0.1, y - 0.01) at a fixed height of
    0.045, then moved along x to the goal's x coordinate.
    """

    def _initialize_skill(self, o_d):
        puck_x, puck_y, sweep_z = np.array([*o_d["puck_pos"][:2], 0.045])
        line_up = [
            MoveYSkill(puck_y - 0.01, grab_effort=-1),
            MoveXSkill(puck_x - 0.1, grab_effort=-1),
            MoveZSkill(sweep_z, grab_effort=-1),
        ]

        # Only the goal's x coordinate is used; y and z are ignored.
        goal_x, _goal_y, _ = o_d["goal_pos"]
        sweep = [MoveXSkill(goal_x, grab_effort=-1)]

        self.skills = line_up + sweep

    @staticmethod
    def _parse_obs(obs):
        # Layout: hand xyz | unused scalar | puck xyz | ... | goal xyz
        return dict(
            hand_pos=obs[:3],
            unused_1=obs[3],
            puck_pos=obs[4:7],
            unused_2=obs[7:-3],
            goal_pos=obs[-3:],
        )


class HandleV2Policy(HeuristicPolicy):
    """Heuristic policy that moves above the handle and pushes it down.

    All skills run with ``grab_effort=1``; the final ``PushSkill`` targets a
    point 0.1 below the handle position.
    """

    def _initialize_skill(self, o_d):
        handle = o_d["handle_pos"]
        hx, hy, hz = handle
        hover = [
            MoveZSkill(hz + 0.2, grab_effort=1),
            MoveXSkill(hx, grab_effort=1),
            MoveYSkill(hy, grab_effort=1),
            MoveZSkill(hz + 0.02, grab_effort=1),
        ]
        press = PushSkill(handle + np.array([0, 0, -0.1]), grab_effort=1)
        self.skills = [*hover, press]

    @staticmethod
    def _parse_obs(obs):
        # Layout: hand xyz | gripper | handle xyz | ... | goal xyz
        hand, grip, handle = obs[:3], obs[3], obs[4:7]
        return {
            "hand_pos": hand,
            "gripper": grip,
            "handle_pos": handle,
            "unused_info": obs[7:-3],
            "goal_pos": obs[-3:],
        }


class DrawerV2Policy(HeuristicPolicy):
    """Heuristic policy that aligns with the drawer and pushes it along +y.

    The aim point is 0.02 below the drawer position; the final ``PushSkill``
    targets the aim point shifted by +0.2 in y.
    """

    def _initialize_skill(self, o_d):
        aim = o_d["drwr_pos"] + np.array([0.0, 0.0, -0.02])
        aim_x, _aim_y, aim_z = aim
        align = [
            MoveXSkill(aim_x, grab_effort=1),
            MoveZSkill(aim_z, grab_effort=1),
        ]
        shove = PushSkill(aim + np.array([0, 0.2, 0]), grab_effort=1)
        self.skills = [*align, shove]

    @staticmethod
    def _parse_obs(obs):
        # Layout: hand xyz | gripper | drawer xyz | everything else
        return dict(
            hand_pos=obs[:3],
            gripper=obs[3],
            drwr_pos=obs[4:7],
            unused_info=obs[7:],
        )


class ButtonV2Policy(HeuristicPolicy):
    """Heuristic policy that lines up with the button and pushes along +y.

    The aim point is 0.07 below the button position; the hand aligns in x, z
    and y (loose ``atol=0.06``) before a ``PushSkill`` toward aim + (0, 0.2, 0.1).
    """

    def _initialize_skill(self, o_d):
        aim = o_d["button_pos"] + np.array([0.0, 0.0, -0.07])
        aim_x, aim_y, aim_z = aim
        align = [
            MoveXSkill(aim_x),
            MoveZSkill(aim_z + 0.1),
            MoveYSkill(aim_y + 0.02, atol=0.06, p=3.0),
        ]
        press = PushSkill(aim + np.array([0.0, 0.2, 0.1]))
        self.skills = [*align, press]

    @staticmethod
    def _parse_obs(obs):
        # Layout: hand xyz | gripper state | button xyz | everything else
        return dict(
            hand_pos=obs[:3],
            hand_closed=obs[3],
            button_pos=obs[4:7],
            unused_info=obs[7:],
        )


class LeverV2Policy(HeuristicPolicy):
    """Heuristic policy that gets under the lever and pushes it up and forward.

    The hand moves to (lever x - 0.01, lever y, lever z - 0.1), then a
    ``PushSkill`` targets lever + (0, 0.2, 0.2).
    """

    def _initialize_skill(self, o_d):
        lever = o_d["lever_pos"]
        lx, ly, lz = lever
        approach = [
            MoveXSkill(lx - 0.01),
            MoveZSkill(lz - 0.1),
            MoveYSkill(ly),
        ]
        lift = PushSkill(lever + np.array([0.0, 0.2, 0.2]))
        self.skills = [*approach, lift]

    @staticmethod
    def _parse_obs(obs):
        # Layout: hand xyz | gripper | lever xyz | ... | goal xyz
        return dict(
            hand_pos=obs[:3],
            gripper=obs[3],
            lever_pos=obs[4:7],
            unused_info=obs[7:-3],
            goal_pos=obs[-3:],
        )


class StickV2Policy(HeuristicPolicy):
    """Heuristic policy that grasps the stick and drags it toward x = -0.5.

    After grasping at the peg position, the hand rises to z = 0.16, matches the
    goal's y coordinate and then moves to x = -0.5.
    """

    def _initialize_skill(self, o_d):
        peg_x, peg_y, peg_z = o_d["peg_pos"]
        grasp = [
            MoveXSkill(peg_x - 0.05, grab_effort=-1),
            MoveYSkill(peg_y, grab_effort=-1, atol=0.02),
            MoveZSkill(peg_z + 0.025, grab_effort=-1),
            GrabSkill(grab_effort=1),
        ]

        drag_x, drag_y, drag_z = np.array([-0.5, o_d["goal_pos"][1], 0.16])
        drag = [
            MoveZSkill(drag_z, grab_effort=1),
            MoveYSkill(drag_y, grab_effort=1),
            MoveXSkill(drag_x, grab_effort=1),
        ]

        self.skills = grasp + drag

    @staticmethod
    def _parse_obs(obs):
        # Layout: [hand(3), gripper(1), peg(3), peg_rot(4), unused(7)] as the
        # current frame, then an 18-wide previous frame, then goal xyz.
        return dict(
            hand_pos=obs[:3],
            gripper_distance_apart=obs[3],
            peg_pos=obs[4:7],
            peg_rot=obs[7:11],
            goal_pos=obs[-3:],
            unused_info_curr_obs=obs[11:18],
            _prev_obs=obs[18:36],
        )


class DoorV2Policy(HeuristicPolicy):
    """Heuristic policy that hooks the door handle and pulls it open.

    The anchor is the door position shifted by -0.05 in x; after a tight
    (``atol=0.002``) x/y alignment the hand descends and a ``PullSkill``
    targets anchor + (-0.5, -0.4, 0).
    """

    def _initialize_skill(self, o_d):
        anchor = o_d["door_pos"] + np.array([-0.05, 0, 0])
        ax, ay, az = anchor
        hook = [
            MoveZSkill(az + 0.12, grab_effort=1),
            MoveXSkill(ax - 0.01, grab_effort=1, atol=0.002),
            MoveYSkill(ay + 0.04, grab_effort=1, atol=0.002),
            MoveZSkill(az + 0.025, grab_effort=1),
        ]
        swing = PullSkill(anchor + np.array([-0.5, -0.4, 0]), grab_effort=1)
        self.skills = [*hook, swing]

    @staticmethod
    def _parse_obs(obs):
        # Layout: hand xyz | gripper | door xyz | everything else
        return dict(
            hand_pos=obs[:3],
            gripper=obs[3],
            door_pos=obs[4:7],
            unused_info=obs[7:],
        )


class ReturnV2Policy(HeuristicPolicy):
    """Drive the hand back to ``init_pos`` (the last three observation entries).

    Args:
        x_first: if true, align x before y; otherwise y before x. The z move
            always comes last.
        *args, **kwargs: forwarded unchanged to ``HeuristicPolicy``.
    """

    def __init__(self, x_first, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.x_first = x_first

    def _initialize_skill(self, o_d):
        home_x, home_y, home_z = o_d["init_pos"]

        if not self.x_first:
            planar = [
                MoveYSkill(home_y, grab_effort=-1),
                MoveXSkill(home_x, grab_effort=-1),
            ]
        else:
            planar = [
                MoveXSkill(home_x, grab_effort=-1),
                MoveYSkill(home_y, grab_effort=-1),
            ]

        self.skills = planar + [MoveZSkill(home_z, grab_effort=-1)]

    @staticmethod
    def _parse_obs(obs):
        # Layout: hand xyz | ... | init (home) xyz
        return dict(
            hand_pos=obs[:3],
            unused=obs[3:-3],
            init_pos=obs[-3:],
        )


# Maps each task name in ``env.skill_list`` to its standard heuristic policy.
metaworld_complex_policy_dict = dict(
    box=BoxV2Policy,
    puck=PuckV2Policy,
    handle=HandleV2Policy,
    button=ButtonV2Policy,
    drawer=DrawerV2Policy,
    lever=LeverV2Policy,
    stick=StickV2Policy,
    door=DoorV2Policy,
)

############# Variant Policy #############


class BoxVariantV2Policy(BoxV2Policy):
    """``BoxV2Policy`` that also stores a per-task ``variant`` dict.

    The skill sequence is inherited unchanged; ``variant`` is only kept on
    the instance.
    """

    def __init__(self, variant):
        super().__init__()
        self.variant = variant


class PuckVariantV2Policy(PuckV2Policy):
    """``PuckV2Policy`` that also stores a per-task ``variant`` dict.

    The skill sequence is inherited unchanged; ``variant`` is only kept on
    the instance.
    """

    def __init__(self, variant):
        super().__init__()
        self.variant = variant


class HandleVariantV2Policy(HandleV2Policy):
    """``HandleV2Policy`` whose final push ``p`` depends on the variant.

    ``variant["goal_resistance"]`` indexes into ``(2, 3.5, 5)`` to pick the
    ``p`` argument of the final ``PushSkill``.
    """

    def __init__(self, variant):
        super().__init__()
        self.variant = variant

    def _initialize_skill(self, o_d):
        handle = o_d["handle_pos"]
        hx, hy, hz = handle
        push_gain = (2, 3.5, 5)[self.variant["goal_resistance"]]
        hover = [
            MoveZSkill(hz + 0.2, grab_effort=1),
            MoveXSkill(hx, grab_effort=1),
            MoveYSkill(hy, grab_effort=1),
            MoveZSkill(hz + 0.02, grab_effort=1),
        ]
        press = PushSkill(handle + np.array([0, 0, -0.1]), p=push_gain, grab_effort=1)
        self.skills = [*hover, press]


class DrawerVariantV2Policy(DrawerV2Policy):
    """``DrawerV2Policy`` whose final push ``p`` depends on the variant.

    ``variant["goal_resistance"]`` indexes into ``(2, 3.5, 5)`` to pick the
    ``p`` argument of the final ``PushSkill``.
    """

    def __init__(self, variant):
        super().__init__()
        self.variant = variant

    def _initialize_skill(self, o_d):
        aim = o_d["drwr_pos"] + np.array([0.0, 0.0, -0.02])
        aim_x, _aim_y, aim_z = aim
        push_gain = (2, 3.5, 5)[self.variant["goal_resistance"]]
        align = [
            MoveXSkill(aim_x, grab_effort=1),
            MoveZSkill(aim_z, grab_effort=1),
        ]
        shove = PushSkill(aim + np.array([0, 0.2, 0]), p=push_gain, grab_effort=1)
        self.skills = [*align, shove]


class ButtonVariantV2Policy(ButtonV2Policy):
    """``ButtonV2Policy`` whose final push ``p`` depends on the variant.

    ``variant["goal_resistance"]`` indexes into ``(2, 3.5, 5)`` to pick the
    ``p`` argument of the final ``PushSkill``.
    """

    def __init__(self, variant):
        super().__init__()
        self.variant = variant

    def _initialize_skill(self, o_d):
        aim = o_d["button_pos"] + np.array([0.0, 0.0, -0.07])
        aim_x, aim_y, aim_z = aim
        push_gain = (2, 3.5, 5)[self.variant["goal_resistance"]]
        align = [
            MoveXSkill(aim_x),
            MoveZSkill(aim_z + 0.1),
            MoveYSkill(aim_y + 0.02, atol=0.06, p=3.0),
        ]
        press = PushSkill(aim + np.array([0.0, 0.2, 0.1]), p=push_gain)
        self.skills = [*align, press]


class LeverVariantV2Policy(LeverV2Policy):
    """``LeverV2Policy`` whose final push ``p`` depends on the variant.

    ``variant["goal_resistance"]`` indexes into ``(2, 3.5, 5)`` to pick the
    ``p`` argument of the final ``PushSkill``.
    """

    def __init__(self, variant):
        super().__init__()
        self.variant = variant

    def _initialize_skill(self, o_d):
        lever = o_d["lever_pos"]
        lx, ly, lz = lever
        push_gain = (2, 3.5, 5)[self.variant["goal_resistance"]]
        # NOTE(review): unlike LeverV2Policy this aligns to lever x exactly
        # (no -0.01 offset) — confirm whether that difference is intended.
        approach = [
            MoveXSkill(lx),
            MoveZSkill(lz - 0.1),
            MoveYSkill(ly),
        ]
        lift = PushSkill(lever + np.array([0.0, 0.2, 0.2]), p=push_gain)
        self.skills = [*approach, lift]


class StickVariantV2Policy(StickV2Policy):
    """``StickV2Policy`` that also stores a per-task ``variant`` dict.

    The skill sequence is inherited unchanged; ``variant`` is only kept on
    the instance.
    """

    def __init__(self, variant):
        super().__init__()
        self.variant = variant


class DoorVariantV2Policy(DoorV2Policy):
    """``DoorV2Policy`` with retuned offsets and a variant-dependent pull.

    ``variant["goal_resistance"]`` indexes into ``(2, 3.5, 5)``; twice that
    value is passed as ``p`` to the final ``PullSkill``.
    """

    def __init__(self, variant):
        super().__init__()
        self.variant = variant

    def _initialize_skill(self, o_d):
        anchor = o_d["door_pos"] + np.array([-0.05, 0, 0])
        ax, ay, az = anchor
        pull_gain = (2, 3.5, 5)[self.variant["goal_resistance"]]
        hook = [
            MoveZSkill(az + 0.12, grab_effort=1),
            MoveXSkill(ax + 0.01, p=5.0, grab_effort=1, atol=0.002),
            MoveYSkill(ay + 0.03, p=5.0, grab_effort=1, atol=0.002),
            MoveZSkill(az + 0.03, grab_effort=1),
        ]
        swing = PullSkill(
            anchor + np.array([-0.5, -0.4, 0]), p=pull_gain * 2, grab_effort=1
        )
        self.skills = [*hook, swing]


# Maps task names to their variant-aware policies. BoxVariantV2Policy,
# HandleVariantV2Policy, LeverVariantV2Policy and DoorVariantV2Policy are
# defined above but are currently not registered here.
metaworld_complex_variant_policy_dict = dict(
    puck=PuckVariantV2Policy,
    button=ButtonVariantV2Policy,
    drawer=DrawerVariantV2Policy,
    stick=StickVariantV2Policy,
)


############# Meta Policy #############


class MetaWorldComplexPolicy:
    """Chain per-task heuristic policies for a multi-task MetaWorld env.

    One policy is built per entry of ``env.skill_list``. Either the standard
    policy or, when ``variant`` is given, the variant-aware one is used.
    After the task at ``env.mode`` changes, a ``ReturnV2Policy`` drives the
    hand back to ``env.hand_init_pos_save`` before the next task policy
    takes over.

    Args:
        env: environment exposing ``skill_list``, ``mode`` and
            ``hand_init_pos_save``.
        variant: optional dict of variant settings. Entries are either keyed
            by task name, or are dicts mapping task name -> value (see
            ``_organize_variant_by_task``).
    """

    # Position of each supported task's slice in the flat observation.
    _TASK_OBS_INDEX = {"puck": 0, "drawer": 1, "button": 2, "stick": 3}
    _HAND_OBS_DIM = 4  # hand xyz + gripper
    _TASK_OBS_DIM = 17  # per-task slice: 14 state entries + 3 goal entries
    _CURR_OBS_DIM = 14
    _GOAL_OBS_DIM = 3

    def __init__(self, env, variant=None):
        super().__init__()
        self.env = env

        if variant is None:
            self.policies = [
                metaworld_complex_policy_dict[skill]() for skill in env.skill_list
            ]
        else:
            variant = self._organize_variant_by_task(variant)
            self.policies = [
                metaworld_complex_variant_policy_dict[task](variant[task])
                for task in env.skill_list
            ]
        self._num_tasks = len(self.policies)

        # Return policies are stacked in reverse so the one for the first task
        # sits at the end of ``self.policies`` and is popped once done. The
        # full stack is kept separately so ``reset`` can restore it.
        self._return_policies = list(
            reversed(
                [
                    ReturnV2Policy(x_first=task in ["puck", "stick"])
                    for task in env.skill_list
                ]
            )
        )
        self.policies += self._return_policies

        self.previous_mode = 0
        self.return_mode = 0

    def get_action(self, obs, *args, **kwargs):
        """Return the action for the current ``env.mode``.

        Raises:
            NotImplementedError: if the current task has no known slice in the
                observation vector.
        """
        task = self.env.skill_list[self.env.mode]
        try:
            flag = self._TASK_OBS_INDEX[task]
        except KeyError:
            raise NotImplementedError(
                f"no observation layout for task {task!r}"
            ) from None

        start = self._HAND_OBS_DIM + flag * self._TASK_OBS_DIM
        task_obs = obs[start : start + self._TASK_OBS_DIM]

        hand_obs = obs[: self._HAND_OBS_DIM]
        curr_obs = task_obs[: self._CURR_OBS_DIM]
        goal_obs = task_obs[-self._GOAL_OBS_DIM :]

        # Sub-policies parse a [hand, state, hand, state, goal] layout; the
        # current frame is repeated where a previous frame would go.
        p_obs = np.concatenate([hand_obs, curr_obs, hand_obs, curr_obs, goal_obs])

        # The task changed: run the pending return policy until it is done.
        if self.previous_mode != self.env.mode:
            if self.policies[-1].is_done:
                self.previous_mode = self.env.mode
                self.policies.pop()
            else:
                # The return policy reads its target from the goal slot.
                p_obs[-self._GOAL_OBS_DIM :] = self.env.hand_init_pos_save
                return self.policies[-1].get_action(p_obs, *args, **kwargs)

        policy = self.policies[self.env.mode]
        return policy.get_action(p_obs, *args, **kwargs)

    def reset(self):
        """Reset every sub-policy and restore the full return-policy stack.

        Mode tracking is set back to 0, matching ``__init__``. This assumes the
        environment also restarts at mode 0 — TODO confirm.
        """
        self.policies = self.policies[: self._num_tasks] + list(
            self._return_policies
        )
        for policy in self.policies:
            policy.reset()
        self.previous_mode = 0
        self.return_mode = 0

    def _organize_variant_by_task(self, variant):
        """Regroup ``variant`` into one settings dict per task.

        A top-level entry keyed by a task name supplies that task's base
        settings. Any dict-valued entry that contains the task name
        contributes ``{key: value[task]}``, overriding base settings.
        The caller's ``variant`` (including nested dicts) is not modified.
        """
        variant_by_task = {}

        for task in self.env.skill_list:
            task_variant = variant[task] if task in variant else {}
            if isinstance(task_variant, dict):
                # Copy so per-key overrides never write into the caller's dict.
                task_variant = task_variant.copy()

            for key, value in variant.items():
                if isinstance(value, dict) and task in value:
                    task_variant[key] = value[task]

            variant_by_task[task] = task_variant

        return variant_by_task
